import copy
import random

import torch

from tools.utils import dim_action_space, preprocess, soft_update


class ContSAC:
    """Soft actor-critic agent for continuous actions with goal concatenation.

    The agent owns an actor, a critic with its target copy and, when
    ``args.double_critic`` is set, a second critic with its own target copy.
    When ``args.type == 0`` the goal is flattened and concatenated to the
    state before it is given to the actor; otherwise the state is used alone.
    """

    def __init__(self,args,action_space,critic,actor,context,critic2=None):
        """Store the networks and create targets and Adam optimizers.

        Args:
            args: Configuration namespace; this class reads ``double_critic``,
                ``lr_sac``, ``weight_decay2``, ``pi_epsilon``, ``device``,
                ``image``, ``embed_sac``, ``type``, ``alpha``,
                ``clip_grad_sac`` and ``tau`` from it.
            action_space: Object exposing ``sample()``; also passed to
                ``dim_action_space``.
            critic: Critic network taking the state(-goal) and action
                concatenated along dim 1.
            actor: Policy network returning a 3-tuple whose first two items
                are used as the action and its log-probability.
            context: Project object providing ``estimator``,
                ``load_model_path``, ``save_model`` and ``path_models``.
            critic2: Second critic, required when ``args.double_critic``.

        Raises:
            ValueError: If ``args.double_critic`` is set without ``critic2``.
        """
        # Fail early with a clear message instead of an AttributeError on
        # ``None.parameters()`` further down.
        if args.double_critic and critic2 is None:
            raise ValueError("critic2 must be provided when args.double_critic is set")

        self.args = args
        self.context = context
        self.action_space = action_space
        self.act_dim = dim_action_space(action_space)
        self.critic = critic
        self.critic2 = critic2
        if args.double_critic:
            self.target_critic2 = copy.deepcopy(critic2)
            self.optimizer_critic2 = torch.optim.Adam(
                self.critic2.parameters(),
                lr=args.lr_sac,
                weight_decay=args.weight_decay2,
            )

        self.target_critic = copy.deepcopy(critic)
        self.optimizer_critic = torch.optim.Adam(
            self.critic.parameters(),
            lr=args.lr_sac,
            weight_decay=args.weight_decay2,
        )
        self.actor = actor
        self.optimizer_actor = torch.optim.Adam(
            self.actor.parameters(),
            lr=args.lr_sac,
            weight_decay=args.weight_decay2,
        )

    def act(self, obs, goals,features=None,random_warmup=False,deterministic=False,**kwargs):
        """Select actions for a batch of observations.

        Args:
            obs: Batch of observations; ``obs.shape[0]`` is the batch size.
            goals: Goals, flattened to ``(batch, -1)`` before concatenation.
            features: Optional precomputed state embeddings, used instead of
                calling the estimator when ``args.image`` or
                ``args.embed_sac`` is set.
            random_warmup: If true, return one sample of the action space.
            deterministic: Forwarded to the actor; also disables the
                epsilon-random branch.
            **kwargs: Ignored.

        Returns:
            A tensor of actions. The policy branch returns a detached CPU
            tensor.
        """
        if random_warmup:
            # NOTE(review): a single action of shape (1, ...) is returned
            # whatever the batch size of ``obs`` is.
            return torch.from_numpy(self.action_space.sample()).unsqueeze(0).to(obs.device)
        if random.random() < self.args.pi_epsilon and not deterministic:
            # NOTE(review): this tensor lives on ``obs.device`` whereas the
            # policy branch below returns a CPU tensor.
            return torch.tensor([self.action_space.sample() for _ in range(obs.shape[0])],device=obs.device)

        goals = goals.to(self.args.device)
        if self.args.image or self.args.embed_sac:
            if features is None:
                state = self.context.estimator.label_embed(preprocess(obs, self.args),act=True)
            else:
                state = features
        else:
            state = obs.to(self.args.device)

        if self.args.type == 0:
            inputs = torch.cat((state, goals.view(obs.shape[0], -1)), 1)
        else:
            inputs = state

        # Only the detached action is returned, so no autograd graph is needed.
        with torch.no_grad():
            action, _, _ = self.actor(inputs, deterministic=deterministic)
        return action.detach().cpu()

    def build_input(self,obs,goals):
        """Concatenate ``obs`` with ``goals`` flattened to ``(batch, -1)``."""
        flat_goals = goals.view(obs.shape[0], -1)
        return torch.cat((obs, flat_goals), 1)

    def get_value(self,batch,**kwargs):
        """Compute the entropy-regularised target value of ``batch``.

        An action is sampled from the actor, scored with the target critic
        (the minimum of both target critics when ``args.double_critic``),
        and the action log-probability divided by ``args.alpha`` is
        subtracted.

        Args:
            batch: Object exposing ``pi_goal_embeddings`` and either
                ``pi_state_embeddings`` (when ``args.image`` or
                ``args.embed_sac``) or ``next_obs``.
            **kwargs: Ignored.

        Returns:
            The value tensor; it is not detached here.
        """
        if self.args.image or self.args.embed_sac:
            state = batch.pi_state_embeddings
        else:
            state = batch.next_obs

        if self.args.type == 0:
            inputs = self.build_input(state, batch.pi_goal_embeddings).detach()
        else:
            inputs = state

        action, action_log_probs, _ = self.actor(inputs, deterministic=False)
        inputs_critic = torch.cat((inputs, action), dim=1)
        values = self.target_critic(inputs_critic)
        if self.args.double_critic:
            values2 = self.target_critic2(inputs_critic)
            values = torch.min(values, values2)
        return values - action_log_probs / self.args.alpha

    def evaluate(self, rollouts,returns,**kwargs):
        """Run one critic update followed by one actor update.

        Args:
            rollouts: Storage exposing ``get_evals()``, whose result provides
                ``obs``, ``actions``, ``pi_goal_embeddings`` and, when
                embeddings are used, ``pi_prev_state_embeddings``.
            returns: Regression targets for the critic(s); moved to
                ``args.device`` and detached.
            **kwargs: Ignored.

        Returns:
            Tuple ``(critic_loss, mean_critic_value, actor_loss)``, all
            detached from the autograd graph.
        """
        batch = rollouts.get_evals()
        obs = batch.obs
        actions = batch.actions
        returns = returns.to(self.args.device)

        if self.args.image or self.args.embed_sac:
            state = batch.pi_prev_state_embeddings
        else:
            state = obs

        goal_use = batch.pi_goal_embeddings
        if self.args.type == 0:
            inputs = self.build_input(state, goal_use).detach()
        else:
            inputs = state
        inputs_critic = torch.cat((inputs, actions), dim=1)

        # Critic optimization.
        cval_loss, values = self.optimize_critic(inputs_critic, returns.detach())

        # Actor optimization: maximise the (min) critic value of the policy's
        # own action minus the log-probability scaled by 1 / alpha.
        best_action, best_action_log_prob, _ = self.actor(inputs)
        pi_input_critic = torch.cat((inputs.detach(), best_action), dim=1)
        val = self.critic(pi_input_critic)

        if self.args.double_critic:
            val2 = self.critic2(pi_input_critic)
            val = torch.min(val, val2)

        self.optimizer_actor.zero_grad()
        pi_loss = (-val + best_action_log_prob / self.args.alpha).mean()
        pi_loss.backward()
        if self.args.clip_grad_sac:
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.args.clip_grad_sac)
        self.optimizer_actor.step()

        # Detach so the reported loss does not keep the graph alive, matching
        # what optimize_critic returns.
        return cval_loss, values, pi_loss.detach()

    def optimize_critic(self,inputs_critic,target):
        """Fit the critic(s) to ``target`` and soft-update their targets.

        Each critic is trained with a mean smooth-L1 loss, optionally with
        gradient-norm clipping (``args.clip_grad_sac``), then its target
        network is updated through ``soft_update`` with ``args.tau``.

        Args:
            inputs_critic: State(-goal) and action concatenated along dim 1.
            target: Regression target for the critic outputs.

        Returns:
            Tuple ``(loss, mean_value)`` of the first critic, both detached.
        """
        values = self.critic(inputs_critic)
        self.optimizer_critic.zero_grad()
        cval_loss = torch.nn.functional.smooth_l1_loss(values, target, reduction="mean")
        cval_loss.backward()
        if self.args.clip_grad_sac:
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.args.clip_grad_sac)
        self.optimizer_critic.step()
        soft_update(self.target_critic, self.critic, self.args.tau)

        if self.args.double_critic:
            values2 = self.critic2(inputs_critic)
            self.optimizer_critic2.zero_grad()
            cval_loss2 = torch.nn.functional.smooth_l1_loss(values2, target, reduction="mean")
            cval_loss2.backward()
            if self.args.clip_grad_sac:
                torch.nn.utils.clip_grad_norm_(self.critic2.parameters(), self.args.clip_grad_sac)
            self.optimizer_critic2.step()
            soft_update(self.target_critic2, self.critic2, self.args.tau)
        return cval_loss.detach(), values.mean().detach()

    def load(self):
        """Restore networks and optimizers from ``<load_model_path>SAC.pt``.

        Does nothing when ``context.load_model_path`` is falsy. After the
        optimizer states are restored, their learning rate is overwritten
        with ``args.lr_sac``.
        """
        if not self.context.load_model_path:
            return

        path = self.context.load_model_path + "SAC.pt"
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=torch.device(self.args.device))

        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.optimizer_critic.load_state_dict(checkpoint['critic_optimizer_state_dict'])
        self.optimizer_actor.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        # The checkpoint's learning rate is replaced by the configured one.
        for param_group in self.optimizer_actor.param_groups:
            param_group["lr"] = self.args.lr_sac
        for param_group in self.optimizer_critic.param_groups:
            param_group["lr"] = self.args.lr_sac
        self.critic.load_state_dict(checkpoint['critic_state_dict'])
        self.target_critic.load_state_dict(checkpoint['target'])

        if self.args.double_critic:
            self.optimizer_critic2.load_state_dict(checkpoint['critic2_optimizer_state_dict'])
            self.critic2.load_state_dict(checkpoint['critic2_state_dict'])
            self.target_critic2.load_state_dict(checkpoint['target2'])
            for param_group in self.optimizer_critic2.param_groups:
                param_group["lr"] = self.args.lr_sac

    def save(self):
        """Write networks and optimizers to ``<path_models>SAC.pt``.

        Does nothing when ``context.save_model`` is falsy. The keys written
        here are the ones ``load`` reads back.
        """
        if not self.context.save_model:
            return

        path = self.context.path_models + "SAC.pt"
        obj = {
            "actor_state_dict": self.actor.state_dict(),
            'critic_optimizer_state_dict': self.optimizer_critic.state_dict(),
            'actor_optimizer_state_dict': self.optimizer_actor.state_dict(),
            'critic_state_dict': self.critic.state_dict(),
            'target': self.target_critic.state_dict(),
        }
        if self.args.double_critic:
            obj['critic2_optimizer_state_dict'] = self.optimizer_critic2.state_dict()
            obj['critic2_state_dict'] = self.critic2.state_dict()
            obj['target2'] = self.target_critic2.state_dict()
        torch.save(obj, path)
